/*
  Copyright 2006-2012 Patrik Jonsson, sunrise@familjenjonsson.org

  This file is part of Sunrise.

  Sunrise is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation; either version 3 of the License, or
  (at your option) any later version.

  Sunrise is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with Sunrise.  If not, see <http://www.gnu.org/licenses/>.
  
*/

#include "analytic_case.h"
#include "blackbody.h"
#include "blackbody_grain.h"
#include "misc.h"
#include "grain_model.h"
#include <algorithm>
#include <cmath>
#include <fstream>
#include <iterator>
#include <stdexcept>
#include <string>
#include <vector>

class pascucci_factory {
public:
  typedef T_grid::T_grid_impl::T_data T_data;
  typedef refinement_accuracy_data<T_data> T_racc;
  typedef cell_tracker<T_data> T_cell_tracker;

  const T_densities rho0_, n_tol_;
  const T_float ri, ro, rd, zd, r_tol_;
  const int maxl, n_lambda, n_threads_;
  
  pascucci_factory (const T_densities& rho0, const T_densities& n_tol,
		    T_float r_tol,
		    int ml, int nl, int nt):
    ri(1), ro(1000), rd (ro/2), zd(ro/8),
    rho0_(rho0), n_tol_(n_tol), r_tol_(r_tol),
    maxl(ml), n_lambda (nl), n_threads_(nt) {};

  T_densities calc_rho(const vec3d& p) {
    T_densities rho;
    resize_like(rho, rho0_);
    rho = 0;
    const T_float r = sqrt(p[0]*p[0] + p[1]*p[1]); 
    const T_float z = p[2];
    if ((r < ro) && (r > ri)) {
      // we are in the allowed range of radii, calculate the density
      const T_float h=zd*pow(r/rd,1.125);
      const T_float f2=exp(-constants::pi/4*pow(z/h,2));
      rho = rho0_ *(rd/r)*f2;
    }
    return rho;
  };

  bool refine_cell_p (const T_cell_tracker& c) {
    const int level = c.code().level();
    if (level>maxl) 
      return false;
    if (level<=4) return true; // a minimum amount of "prerefinement"
    // refine until cell column density < n_tol

    // simone uses maxlevel=10+log(r/rmin)
    const vec3d min = c.getmin();
    const vec3d max = c.getmax();
    /*    
    T_densities rhoav
      ((calc_rho(c.getcenter())+
	calc_rho(min)+
	calc_rho(vec3d(min[0],min[1],max[2]))+
	calc_rho(vec3d(min[0],max[1],min[2]))+
	calc_rho(vec3d(min[0],max[1],max[2]))+
	calc_rho(vec3d(max[0],min[1],min[2]))+
	calc_rho(vec3d(max[0],min[1],max[2]))+
	calc_rho(vec3d(max[0],max[1],min[2]))+
	calc_rho(max))/9);
    */
    const T_float sz=sqrt(dot(c.getsize(),c.getsize()));
    T_densities rhoav
      (-log((exp(-sz*calc_rho(c.getcenter()))+
	     exp(-sz*calc_rho(min))+
	     exp(-sz*calc_rho(vec3d(min[0],min[1],max[2])))+
	     exp(-sz*calc_rho(vec3d(min[0],max[1],min[2])))+
	     exp(-sz*calc_rho(vec3d(min[0],max[1],max[2])))+
	     exp(-sz*calc_rho(vec3d(max[0],min[1],min[2])))+
	     exp(-sz*calc_rho(vec3d(max[0],min[1],max[2])))+
	     exp(-sz*calc_rho(vec3d(max[0],max[1],min[2])))+
	     exp(-sz*calc_rho(max)))/9)/sz) ;
    
    // refine cell if cellsize/distance from source > r_tol_ or if it
    // is within one cellsize of the boundary (assuming some of the
    // densities are nonzero)
    vec3d cylr=c.getcenter();
    cylr[2]=0;
    if ((any(rhoav)>0) && 
	//( (dot(c.getsize(), c.getsize()/dot(c.getcenter(), c.getcenter())) >
	( (dot(c.getsize(), c.getsize()/dot(cylr,cylr)) >
	   r_tol_*r_tol_) ||
	  (2*abs(sqrt(dot(cylr,cylr)-ri))<
	   sqrt(dot(c.getsize(),c.getsize())))
	  ))
      return true;

    // refine cell if column density exceeds n_tol
    // length required for tau 1
    //T_densities abslen(64*n_tol_/rhosum);
    //if (condition_all(abslen>50))
      return condition_all (rhoav*rhoav*dot(c.getsize(),c.getsize()) > 
			    n_tol_*n_tol_);
      //else
      // hopelessly optically thick, just go by the radial criterion
      //return false;
  };
  bool unify_cell_p (const T_cell_tracker& c,
		     const T_racc& racc) {
    return false;
    /*
    vec3d center = begin->getcenter();
    T_float vol = begin->volume();
    T_densities s ((begin++)->data()->get_absorber().opacity());
    T_densities ssq (s*s);
    int n=1;
    while (begin != end) {
      center += begin->getcenter();
      vol += begin->volume();
      T_densities o = (begin++)->data()->get_absorber().opacity();
      s+=o;
      ssq+=o*o;
      n++;
    }

    const T_densities mean (s/n);
    const T_densities stdev2 (ssq/n- mean*mean);
    const T_float r2=dot(center,center)/(n*n);
    return condition_all ((mean==0) || ((stdev2/(mean*mean))<tol) || 
			  (stdev2<r2*r2/(vol*vol)*tabs*tabs));
    */
  };

  T_data get_data (const T_cell_tracker& c) {
    const vec3d min = c.getmin();
    const vec3d max = c.getmax();
    /*
    T_densities rhoav
      ((calc_rho(c.getcenter())+
	calc_rho(min)+
	calc_rho(vec3d(min[0],min[1],max[2]))+
	calc_rho(vec3d(min[0],max[1],min[2]))+
	calc_rho(vec3d(min[0],max[1],max[2]))+
	calc_rho(vec3d(max[0],min[1],min[2]))+
	calc_rho(vec3d(max[0],min[1],max[2]))+
	calc_rho(vec3d(max[0],max[1],min[2]))+
	calc_rho(max))/9);
    */
    const T_float sz=sqrt(dot(c.getsize(),c.getsize()));
    T_densities rhoav
      (-log((exp(-sz*calc_rho(c.getcenter()))+
	     exp(-sz*calc_rho(min))+
	     exp(-sz*calc_rho(vec3d(min[0],min[1],max[2])))+
	     exp(-sz*calc_rho(vec3d(min[0],max[1],min[2])))+
	     exp(-sz*calc_rho(vec3d(min[0],max[1],max[2])))+
	     exp(-sz*calc_rho(vec3d(max[0],min[1],min[2])))+
	     exp(-sz*calc_rho(vec3d(max[0],min[1],max[2])))+
	     exp(-sz*calc_rho(vec3d(max[0],max[1],min[2])))+
	     exp(-sz*calc_rho(max)))/9)/sz) ;
    //const T_densities rhoav = calc_rho(c.getcenter());
    return T_data (T_data::T_emitter(),
		   T_data::T_absorber(rhoav, vec3d(0,0,0), T_lambda (n_lambda)));
  }
    
  int n_threads () {return n_threads_;};
};

// Reciprocal of the dust opacity at the reference (~550nm) wavelength,
// i.e. the column density that gives unit optical depth there. Set in
// setup() and used by pre_shoot() to turn column densities into
// optical depths.
T_float unit_opacity_rho;

/** Set up and run the Pascucci et al. (2004) disk benchmark.

    Reads the benchmark wavelength grid from a file in the working
    directory, loads the grain model, converts the requested midplane
    V-band optical depth into a density normalization, builds the grid
    factory, a solar-type point source and the cameras, and calls
    run_case().

    @throws std::runtime_error if the wavelength file cannot be opened,
            is empty, or has no entry below 551nm. */
void setup (po::variables_map& opt)
{
  T_unit_map units;
  units ["length"] = "au";
  units ["mass"] = "Msun";
  units ["luminosity"] = "W";
  units ["wavelength"] = "m";
  units ["L_lambda"] = "W/m";

  const int n_threads = opt["n_threads"].as<int>();
  const int maxlevel = opt["maxlevel"].as<int>();
  // make grid extent larger in xy plane to get flattened cells
  const vec3d grid_extent(4000,4000,1000);
  const T_float cameradist = 1e5;
  const int npix = opt["npix"].as<int>();
  const T_float r_tol =opt["r_tol"].as<T_float>();
  const T_float tau = opt["tau"].as<T_float>();
  const T_float tau_tol = opt["tau_tol"].as<T_float>();
  const string grain_dir = word_expand(opt["grain_data_dir"].as<string>())[0];
  const string grain_model = opt["grain_model"].as<string>();

  // opacity for the WD01 dust = 2.41e-6 kpc^2/Msun (=1.025e11 au^2/Msun)
  // for pascucci grid, n=rho0*rd*6.9078
  // which implies tau = 3.54e14 au^3/Msun * rho0

  // read wavelengths (wavelength options are ignored in this case)
  const char* const lambda_file =
    (grain_model=="p04") ? "pascucci_v2.wavelengths" : "pascucci.wavelengths";
  std::ifstream lfile(lambda_file);
  if (!lfile)
    throw std::runtime_error(std::string("Could not open wavelength file ") +
                             lambda_file);
  // Parse directly as T_float: going through float would truncate the
  // wavelengths to single precision.
  std::vector<T_float> lambdas((std::istream_iterator<T_float>(lfile)),
                               std::istream_iterator<T_float>());
  if (lambdas.empty())
    throw std::runtime_error(std::string("No wavelengths read from ") +
                             lambda_file);
  // lambda views the vector's storage; lambdas outlives run_case below.
  array_1 lambda(&lambdas[0], blitz::TinyVector<int,1>(lambdas.size()),
                 blitz::neverDeleteData);
  // Last entry below 0.551um, i.e. the 550nm entry. This assumes the
  // wavelength file is sorted in increasing order.
  const int lambda550 = std::lower_bound(lambdas.begin(), lambdas.end(),
                                         0.551e-6) - lambdas.begin() - 1;
  if (lambda550 < 0)
    throw std::runtime_error(std::string("No wavelength below 551nm in ") +
                             lambda_file);
  cout << "550nm is lambda entry " << lambda550 << endl;

  // we need a preferences object to pass parameters to grain model
  Preferences p;
  p.setValue("grain_data_directory", grain_dir);
  p.setValue("grain_model", grain_model);
  p.setValue("wd01_parameter_set", opt["wd01_parameter_set"].as<string>());
  p.setValue("template_pah_fraction", 0.5);
  p.setValue("use_dl07_opacities", true);
  p.setValue("n_threads", n_threads);
  p.setValue("use_grain_temp_lookup", true);
  p.setValue("use_cuda", false);

  // load grain model
  T_dust_model::T_scatterer_vector sv;
  sv.push_back(load_grain_model<polychromatic_scatterer_policy, mcrx_rng_policy>(p,units));
  T_dust_model model(sv);
  model.set_wavelength(lambda); // this also resamples the grain_model objects

  // Convert optical depth to density. The midplane column from ri to ro
  // is rho0*rd*ln(ro/ri), with rd = 500 and ln(1000) ~= 6.9078 (these
  // must match the radii hard-coded in pascucci_factory).
  unit_opacity_rho = 1.0/model.get_scatterer(0).opacity()(lambda550);
  T_densities rho(1);
  rho = tau*unit_opacity_rho/(500*6.9078);
  T_densities n_tol(1); n_tol = tau_tol*unit_opacity_rho;
  cout << "Unit opacity density: " << unit_opacity_rho
       <<  ", n_tol: " << n_tol(0) << endl;

  // build grid
  pascucci_factory factory (rho,n_tol, r_tol, maxlevel,
                            lambda.size(), n_threads);

  // central point source: 5800K blackbody with solar luminosity
  blackbody bb(5800,3.91e26);
  boost::shared_ptr<T_emission> em
    (new pointsource_emission<polychromatic_policy, T_rng_policy>
     (vec3d (.0, .0, .0), bb.emission(lambda)));

  cout << "setting up emergence" << endl;
  std::vector<std::pair<T_float, T_float> > cam_pos;
  cam_pos.push_back(std:: make_pair (12.5*constants::pi/180, 0.1) );
  cam_pos.push_back(std:: make_pair (42.5*constants::pi/180, 0.1) );
  cam_pos.push_back(std:: make_pair (77.5*constants::pi/180, 0.1) );
  cam_pos.push_back(std:: make_pair (65.5*constants::pi/180, 0.1) );
  cam_pos.push_back(std:: make_pair (88.5*constants::pi/180, 0.1) );
  //cam_pos.push_back(std:: make_pair (77.5*constants::pi/180, .393) );
  //cam_pos.push_back(std:: make_pair (77.5*constants::pi/180, .785) );
  boost::shared_ptr<T_emergence> e
    (new T_emergence(cam_pos, cameradist, 2.3*grid_extent[2], npix));

  run_case(opt,
           factory,
           units,
           lambda,
           -grid_extent,
           grid_extent,
           *em,
           *e,
           model);
}

/** Diagnostic hook run before shooting.

    Integrates the column density from the origin along (1,0,0), which
    lies in the disk midplane. The two vectors are presumably a start
    point and a direction; confirm against T_xfer. The first species is
    divided by unit_opacity_rho to print the optical depth at the 550nm
    reference wavelength. */
void pre_shoot(T_xfer& x)
{
  const vec3d origin(0,0,0);
  const vec3d direction(1,0,0);
  T_densities column(x.integrate_column_density(origin, direction));
  const T_float tau_midplane = column(0)/unit_opacity_rho;
  std::cout << "Actual V-band optical depth in disk midplane: "
            << tau_midplane << std::endl;
}

// Short human-readable name of this analytic test case.
std::string description()
{
  const std::string text("Pascucci et al 04 benchmark.");
  return text;
}

/** Register the command-line options specific to this case.

    Options: r_tol (refinement tolerance on cell size relative to radius),
    tau (disk optical depth; setup() uses it as the midplane column from
    the inner to the outer radius) and grain_model. */
void add_custom_options(po::options_description& desc)
{
  po::options_description_easy_init add = desc.add_options();
  add("r_tol", po::value<T_float>()->default_value(0.2),
      "cell radial extent tolerance");
  add("tau", po::value<T_float>()->default_value(100),
      "edge-on V-band optical depth of disk");
  add("grain_model", po::value<string>()->default_value("p04"),
      "dust grain model");
}

// Echo the values of this case's custom options to standard output.
void print_custom_options(const po::variables_map& opt)
{
  const T_float r_tol = opt["r_tol"].as<T_float>();
  const T_float tau = opt["tau"].as<T_float>();
  const std::string grain_model = opt["grain_model"].as<std::string>();
  cout << "r_tol = " << r_tol << '\n'
       << "tau = " << tau << '\n'
       << "grain_model = " << grain_model << '\n';
}
